#!/usr/bin/env python3
# This file is part of Xpra.
# Copyright (C) 2019 Antoine Martin <antoine@xpra.org>
# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
# later version. See the file COPYING for details.

import unittest
from gi.repository import GLib  # @UnresolvedImport

from xpra.util.objects import typedict, AdHocStruct


class ClientMixinTest(unittest.TestCase):
    """
    Base class for client mixin unit tests.

    Provides fake replacements for the mixin's packet sending and packet
    handler registration methods, so a mixin can be exercised without a real
    connection: sent packets are recorded in ``self.packets``, registered
    handlers in ``self.packet_handlers`` and quit codes in ``self.exit_codes``.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # the GLib module and main loop are stored on the class,
        # so they are shared by every test method of the class:
        cls.glib = GLib
        cls.main_loop = cls.glib.MainLoop()

    def setUp(self):
        super().setUp()
        self.packets = []           # argument tuples passed to the fake send()
        self.mixin = None           # the mixin instance under test, if any
        self.packet_handlers = {}   # packet type -> handler callable
        self.exit_codes = []        # codes passed to the fake quit()

    def tearDown(self):
        """
        Clean up the mixin under test, if one was created.

        The reference is cleared *before* calling ``cleanup()`` so that a
        failing cleanup cannot leave a stale mixin attached to the test.
        """
        super().tearDown()
        mixin, self.mixin = self.mixin, None
        if mixin:
            mixin.cleanup()

    def stop(self):
        """Schedule the shared main loop to quit after 1000ms."""
        self.glib.timeout_add(1000, self.main_loop.quit)

    def debug_all(self):
        """Enable xpra debug logging for all categories."""
        from xpra.log import enable_debug_for
        enable_debug_for("all")

    def dump_packets(self):
        """Log every packet recorded so far, for debugging failing tests."""
        from xpra.util.io import get_util_logger
        log = get_util_logger()
        log.info("dump_packets() %i packets to send:", len(self.packets))
        for x in self.packets:
            log.info("%s", x)

    def send(self, *args):
        """Fake send method: record the packet arguments as a tuple."""
        self.packets.append(args)

    def get_packet(self, index):
        """
        Return the recorded packet at ``index``.

        Negative indexes count from the end, as they do for lists.

        Raises:
            AssertionError: if ``index`` does not refer to a recorded packet.
                (raised explicitly so the check survives ``python -O``)
        """
        count = len(self.packets)
        actual_index = count + index if index < 0 else index
        if actual_index < 0:
            raise AssertionError("invalid actual index %i for index %i" % (actual_index, index))
        if actual_index >= count:
            raise AssertionError("not enough packets (%i) to access %i" % (count, index))
        return self.packets[actual_index]

    def verify_packet(self, index, expected):
        """
        Check that the packet at ``index`` starts with the items in ``expected``.

        ``expected`` may be any sequence (tuple, list, ...): both sides are
        compared as tuples, since recorded packets are tuples.

        Raises:
            AssertionError: if the packet prefix does not match.
        """
        expected = tuple(expected)
        packet = self.get_packet(index)
        pslice = tuple(packet[:len(expected)])
        if pslice != expected:
            raise AssertionError("invalid packet slice %s, expected %s" % (pslice, expected))

    def add_packet_handler(self, packet_type, handler, _main_thread=True):
        """Fake registration of a single packet handler (thread flag is ignored)."""
        self.packet_handlers[packet_type] = handler

    def add_packet_handlers(self, defs, _main_thread=True):
        """Fake registration of a dict of packet handlers (thread flag is ignored)."""
        self.packet_handlers.update(defs)

    def handle_packet(self, packet):
        """
        Dispatch ``packet`` to the handler registered for its type (``packet[0]``).

        Raises:
            AssertionError: if no handler is registered for that packet type.
        """
        packet_type = packet[0]
        ph = self.packet_handlers.get(packet_type)
        if ph is None:
            raise AssertionError("no packet handler for %s" % packet_type)
        ph(packet)

    def fake_quit(self, code):
        """Fake quit method: record the exit code instead of exiting."""
        self.exit_codes.append(code)

    def _test_mixin_class(self, mclass, opts, caps=None, protocol_type="xpra"):
        """
        Instantiate ``mclass``, wire it to the test fakes and run it through
        initialization, connection setup and capability parsing.

        Args:
            mclass: the mixin class to instantiate (called with no arguments).
            opts: options object passed to the mixin's ``init()``.
            caps: optional dict of server capabilities, wrapped by ``make_caps``.
            protocol_type: value exposed as ``TYPE`` on the fake protocol.

        Returns:
            the initialized mixin instance (also stored as ``self.mixin``
            so that ``tearDown`` cleans it up).
        """
        x = self.mixin = mclass()
        x.server_packet_types = ()
        x.quit = self.fake_quit
        fake_protocol = AdHocStruct()
        fake_protocol.get_info = lambda : {}
        fake_protocol.set_compression_level = lambda _x : None
        fake_protocol.TYPE = protocol_type
        x._protocol = fake_protocol   # pylint: disable=protected-access
        x.add_packet_handlers = self.add_packet_handlers
        x.add_packet_handler = self.add_packet_handler
        x.idle_add = self.glib.idle_add
        x.timeout_add = self.glib.timeout_add
        x.source_remove = self.glib.source_remove
        x.init(opts)
        conn = AdHocStruct()
        conn.filename = "/fake/path/to/nowhere"
        x.setup_connection(conn)
        # note: send/send_now are only redirected here, after init() and
        # setup_connection(), so anything sent before this point is not recorded
        x.send = self.send
        x.send_now = self.send
        x.init_authenticated_packet_handlers()
        caps = self.make_caps(caps)
        x.parse_server_capabilities(caps)
        x.process_ui_capabilities(caps)
        assert x.get_caps() is not None
        assert x.get_info() is not None
        return x

    def make_caps(self, caps=None):
        """Wrap ``caps`` (or an empty dict when falsy) in a ``typedict``."""
        return typedict(caps or {})
